# Copyright (c) 2025 The HuggingFace Team.
# Copyright (c) 2025 Bytedance Ltd. and/or its affiliates
# SPDX-License-Identifier: Apache-2.0 
#
# This file has been modified by Bytedance Ltd. and/or its affiliates on September 15, 2025.
#
# Original file was released under Apache License 2.0, with the full license text
# available at https://github.com/huggingface/finetrainers/blob/main/LICENSE.
#
# This modified file is released under the same license.


from enum import Enum
from typing import Type

from .models import ModelSpecification
from .models.cogvideox import CogVideoXModelSpecification
from .models.cogview4 import CogView4ControlModelSpecification, CogView4ModelSpecification
from .models.flux import FluxModelSpecification
from .models.hunyuan_video import HunyuanVideoModelSpecification
from .models.ltx_video import LTXVideoModelSpecification
from .models.wan import WanControlModelSpecification, WanModelSpecification


class ModelType(str, Enum):
    """Model families that have entries in ``SUPPORTED_MODEL_CONFIGS``.

    Because the enum mixes in ``str``, each member hashes and compares equal to its string value.
    A plain string such as ``"wan"`` can therefore be used directly as a key into
    ``SUPPORTED_MODEL_CONFIGS``.
    """

    COGVIDEOX = "cogvideox"
    COGVIEW4 = "cogview4"
    FLUX = "flux"
    HUNYUAN_VIDEO = "hunyuan_video"
    LTX_VIDEO = "ltx_video"
    WAN = "wan"


class TrainingType(str, Enum):
    """Training modes that a model family may register in ``SUPPORTED_MODEL_CONFIGS``.

    Like ``ModelType``, this enum mixes in ``str``, so plain strings such as ``"lora"`` can be
    used as lookup keys.
    """

    # SFT: LoRA or full fine-tuning with the model's standard specification.
    LORA = "lora"
    FULL_FINETUNE = "full-finetune"

    # Control: only registered against the dedicated *Control* specifications (CogView4, Wan).
    CONTROL_LORA = "control-lora"
    CONTROL_FULL_FINETUNE = "control-full-finetune"

    # Video-as-prompt "MoT" training (presumably Mixture-of-Transformers -- TODO confirm).
    # Registered for CogVideoX and Wan.
    VIDEO_AS_PROMPT_MOT = "video-as-prompt-mot"


# Registry of supported combinations: maps each ModelType to a dict from TrainingType to the
# ModelSpecification subclass (a class, not an instance) that implements it. It is read by
# `_get_model_specifiction_cls`. Insertion order here is the order in which supported models
# and training types are listed in that function's error messages.
SUPPORTED_MODEL_CONFIGS = {
    # TODO(aryan): autogenerate this
    # Each inner dict may contain SFT (LORA / FULL_FINETUNE), control (CONTROL_*), and
    # video-as-prompt MoT entries; not every model registers every training type.
    ModelType.COGVIDEOX: {
        TrainingType.LORA: CogVideoXModelSpecification,
        TrainingType.FULL_FINETUNE: CogVideoXModelSpecification,
        # Video-as-prompt MoT reuses the standard (non-control) CogVideoX specification.
        TrainingType.VIDEO_AS_PROMPT_MOT: CogVideoXModelSpecification,
    },
    ModelType.COGVIEW4: {
        TrainingType.LORA: CogView4ModelSpecification,
        TrainingType.FULL_FINETUNE: CogView4ModelSpecification,
        # Control training is served by a separate specification class.
        TrainingType.CONTROL_LORA: CogView4ControlModelSpecification,
        TrainingType.CONTROL_FULL_FINETUNE: CogView4ControlModelSpecification,
    },
    ModelType.FLUX: {
        TrainingType.LORA: FluxModelSpecification,
        TrainingType.FULL_FINETUNE: FluxModelSpecification,
    },
    ModelType.HUNYUAN_VIDEO: {
        TrainingType.LORA: HunyuanVideoModelSpecification,
        TrainingType.FULL_FINETUNE: HunyuanVideoModelSpecification,
    },
    ModelType.LTX_VIDEO: {
        TrainingType.LORA: LTXVideoModelSpecification,
        TrainingType.FULL_FINETUNE: LTXVideoModelSpecification,
    },
    ModelType.WAN: {
        TrainingType.LORA: WanModelSpecification,
        TrainingType.FULL_FINETUNE: WanModelSpecification,
        # Control training is served by a separate specification class.
        TrainingType.CONTROL_LORA: WanControlModelSpecification,
        TrainingType.CONTROL_FULL_FINETUNE: WanControlModelSpecification,
        # Video-as-prompt MoT reuses the standard (non-control) Wan specification.
        TrainingType.VIDEO_AS_PROMPT_MOT: WanModelSpecification,
    },
}


def _get_model_specifiction_cls(model_name: str, training_type: str) -> Type[ModelSpecification]:
    """Return the ModelSpecification subclass registered for a model / training-type pair.

    Args:
        model_name: A ``ModelType`` member or its string value (e.g. ``"wan"``). Plain strings
            work because ``ModelType`` mixes in ``str``, so they hash and compare equal to the
            enum members used as keys in ``SUPPORTED_MODEL_CONFIGS``.
        training_type: A ``TrainingType`` member or its string value (e.g. ``"lora"``).

    Returns:
        The specification class (not an instance) from ``SUPPORTED_MODEL_CONFIGS``.

    Raises:
        ValueError: If ``model_name`` is not registered, or if ``training_type`` is not
            registered for that model.
    """
    training_configs = SUPPORTED_MODEL_CONFIGS.get(model_name)
    if training_configs is None:
        # List the plain string values (what callers pass in) rather than enum reprs such as
        # "<ModelType.WAN: 'wan'>". Keys that are not enums are shown unchanged.
        supported_models = [getattr(key, "value", key) for key in SUPPORTED_MODEL_CONFIGS]
        raise ValueError(f"Model {model_name} not supported. Supported models are: {supported_models}")

    spec_cls = training_configs.get(training_type)
    if spec_cls is None:
        supported_training_types = [getattr(key, "value", key) for key in training_configs]
        raise ValueError(
            f"Training type {training_type} not supported for model {model_name}. Supported training types are: {supported_training_types}"
        )
    return spec_cls
